import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from preprompt import PrePrompt
from RAGraph import RAGraph
from ragraph_utils import seed_everything, process_tu_dataset

# ---- Experiment configuration ----
datasets = ["ENZYMES", "PROTEINS", "BZR", "COX2"]  # TUDataset names to evaluate
shots = list(range(1, 6))  # few-shot settings: 1-shot through 5-shot
num_tasks = 5  # tasks per shot setting; each task uses its own random seed
batch_size = 16
lr = 0.001
ring_loss_weight = 0.1  # weight of the auxiliary ring loss in the fine-tuning objective
downstream_epochs = 200  # upper bound on fine-tuning epochs (early stopping may end sooner)
test_times = 5  # runs per setting (not referenced by the code in this file)

# Output directory for the summary JSON and the accuracy plot.
os.makedirs("results", exist_ok=True)


# 训练与评估过程
def _train_one_epoch(rag_model, optimizer, loader, num_classes, feature_size, desc):
    """Run one optimization pass over ``loader``; return (avg total loss, avg ring loss).

    The per-batch objective is ``cls_loss + ring_loss_weight * ring_loss``. Here
    ``cls_loss`` is the first value returned by ``rag_model.forward_with_loss``, and
    ``ring_loss`` is read from its ``debug_info`` dict (treated as 0 when absent).
    Averages are taken over batches.
    """
    rag_model.train()
    loss_sum = 0.0
    ring_sum = 0.0
    num_batches = 0

    for data in tqdm(loader, desc=desc):
        features, adj, labels, complex_batch, batch = process_tu_dataset(data, num_classes, feature_size)
        optimizer.zero_grad()

        cls_loss, logits, debug_info = rag_model.forward_with_loss(
            features, adj, complex_batch=complex_batch, label=labels, batch=batch
        )

        ring_loss = debug_info.get("ring_loss")
        if ring_loss is None:
            ring_loss = torch.zeros((), device=logits.device)
        # Accept plain Python numbers too; an existing tensor on this device is returned as-is
        # (the autograd graph is preserved).
        ring_loss = torch.as_tensor(ring_loss, device=logits.device)

        batch_loss = cls_loss + ring_loss_weight * ring_loss
        batch_loss.backward()
        optimizer.step()

        # Accumulate into separate float counters (the original overwrote its accumulator).
        loss_sum += batch_loss.item()
        ring_sum += ring_loss.item()
        num_batches += 1

    denom = max(num_batches, 1)
    return loss_sum / denom, ring_sum / denom


def _evaluate(rag_model, loader, num_classes, feature_size):
    """Return (mean per-batch cross-entropy loss, accuracy in %) of ``rag_model`` on ``loader``."""
    rag_model.eval()
    loss_sum = 0.0
    correct, total, num_batches = 0, 0, 0

    with torch.no_grad():
        for data in loader:
            features, adj, labels, complex_batch, batch = process_tu_dataset(data, num_classes, feature_size)
            logits = rag_model(features, adj, complex_batch=complex_batch, batch=batch)
            loss_sum += torch.nn.functional.cross_entropy(logits, labels).item()

            # softmax is monotonic, so argmax over raw logits gives the same prediction.
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    avg_loss = loss_sum / num_batches if num_batches else float('inf')
    accuracy = 100 * correct / total if total else 0.0
    return avg_loss, accuracy


def run_experiment(dataset_name, shot_k, task_id, feature_size, num_classes, pretrain_model):
    """Fine-tune a RAGraph model for one (dataset, shot, task) setting.

    Seeds all RNGs with ``3407 + task_id``, then shuffles the TU dataset and splits it
    50% / 30% / 20%. The first 50% is the RAGraph resource dataset and the next 30% is
    the validation split. The last 20% is left for the caller to test on. Training runs
    with Adam and early stopping (patience 5) on validation loss. The best checkpoint is
    written with ``torch.save``.

    Args:
        dataset_name: TUDataset name, e.g. "ENZYMES".
        shot_k: few-shot setting forwarded to ``RAGraph``.
        task_id: task index; selects the random seed, and therefore the data split.
        feature_size: node-attribute dimensionality passed to ``RAGraph``.
        num_classes: number of graph classes.
        pretrain_model: pretrained encoder wrapped by ``RAGraph``.

    Returns:
        Path of the saved best checkpoint (a ``state_dict``).
    """
    # Each task uses a different random seed.
    seed_everything(3407 + task_id)
    dataset = TUDataset(root='data', name=dataset_name, use_node_attr=True)

    # Split the dataset. The test split (last 20%) is evaluated by the caller, not here.
    dataset = dataset.shuffle()
    num_graphs = len(dataset)
    train_dataset = dataset[:int(0.5 * num_graphs)]
    val_dataset = dataset[int(0.5 * num_graphs):int(0.8 * num_graphs)]

    rag_model = RAGraph(
        pretrain_model,
        resource_dataset=train_dataset,
        feature_size=feature_size,
        num_class=num_classes,
        emb_size=256,
        finetune=True,
        noise_finetune=False,
        dataset_name=dataset_name,
        shot_k=shot_k,
    ).cuda()

    optimizer = torch.optim.Adam(rag_model.parameters(), lr=lr)
    os.makedirs("modelset", exist_ok=True)
    finetune_model_name = f"modelset/finetune_rag_model_{dataset_name}_{shot_k}_task{task_id}.pkl"

    # NOTE(review): fine-tuning and early stopping both use val_loader, so the
    # "validation" loss is measured on the training data. Confirm this is intended
    # for the few-shot protocol; otherwise train on a separate (few-shot) loader.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    patience = 5  # epochs without validation-loss improvement before stopping
    patience_counter = 0
    best_loss = float('inf')

    for epoch in range(downstream_epochs):
        avg_loss, avg_ring_loss = _train_one_epoch(
            rag_model, optimizer, val_loader, num_classes, feature_size,
            desc=f'Task {task_id} Epoch {epoch}',
        )
        avg_val_loss, val_accuracy = _evaluate(rag_model, val_loader, num_classes, feature_size)

        print(f"[Epoch {epoch}] train_loss = {avg_loss:.4f}, ring_loss = {avg_ring_loss:.4f}, "
              f"val_loss = {avg_val_loss:.4f}, val_accuracy = {val_accuracy:.4f}%")

        # Early-stopping check: keep only the best-so-far checkpoint.
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(rag_model.state_dict(), finetune_model_name)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"[Early Stop] Epoch {epoch} validation loss hasn't improved.")
                break

    return finetune_model_name



def main():
    """Fine-tune and test RAGraph for every (dataset, shot, task) combination.

    For each combination, this function:
      1. loads the pretrained PrePrompt weights from ``modelset/model_{dataset}.pkl``;
      2. fine-tunes via ``run_experiment``;
      3. reloads the best checkpoint into a fresh RAGraph;
      4. measures accuracy on the held-out 20% test split.

    Prints per-task, pooled and per-setting accuracy. Writes
    ``results/finetune_rag_summary.json`` and ``results/accuracy_curve.png``.
    """
    accuracy_list = []
    # Accuracies grouped per "<dataset>_<k>shot" so each setting is summarized on its own.
    per_setting = {}

    for dataset_name in datasets:
        # Dataset metadata does not depend on shot/task, so load it once per dataset.
        dataset = TUDataset(root='data', name=dataset_name, use_node_attr=True)
        feature_size = dataset.num_node_attributes
        num_classes = dataset.num_classes
        num_graphs = len(dataset)

        for shot_k in shots:
            for task_id in range(num_tasks):
                print(f"Running task {task_id + 1} for {dataset_name} dataset with {shot_k}-shot")

                # Reload pretrained weights every task: fine-tuning may update this module
                # in place, and tasks must start from the same weights.
                pretrain_model = PrePrompt(feature_size, 256, 'prelu', 1, 0.3, use_proj=True)
                pretrain_model.load_state_dict(torch.load(f'modelset/model_{dataset_name}.pkl'))
                pretrain_model = pretrain_model.cuda()

                finetune_model_name = run_experiment(dataset_name, shot_k, task_id, feature_size, num_classes,
                                                     pretrain_model)

                # Reproduce run_experiment's split (seed 3407 + task_id, then shuffle) so that
                # testing uses graphs unseen during fine-tuning. Slicing the unshuffled dataset
                # would overlap the fine-tuning data.
                # NOTE(review): assumes seed_everything seeds torch's global RNG, which
                # Dataset.shuffle draws from — confirm.
                seed_everything(3407 + task_id)
                shuffled = dataset.shuffle()
                resource_split = shuffled[:int(0.5 * num_graphs)]
                test_split = shuffled[int(0.8 * num_graphs):]

                # Build the model exactly as in fine-tuning so the checkpoint matches.
                rag_model = RAGraph(
                    pretrain_model,
                    resource_dataset=resource_split,
                    feature_size=feature_size,
                    num_class=num_classes,
                    emb_size=256,
                    finetune=True,
                    noise_finetune=False,
                    dataset_name=dataset_name,
                    shot_k=shot_k,
                ).cuda()

                rag_model.load_state_dict(torch.load(finetune_model_name))
                rag_model.eval()

                test_loader = DataLoader(test_split, batch_size=batch_size, shuffle=False)
                correct, total = 0, 0
                with torch.no_grad():
                    for data in test_loader:
                        features, adj, labels, complex_batch, batch = process_tu_dataset(data, num_classes, feature_size)
                        logits = rag_model(features, adj, complex_batch=complex_batch, batch=batch)
                        preds = logits.argmax(dim=1)
                        correct += (preds == labels).sum().item()
                        total += labels.size(0)

                acc = 100 * correct / total
                accuracy_list.append(acc)
                per_setting.setdefault(f"{dataset_name}_{shot_k}shot", []).append(acc)

                print(f"Accuracy for {dataset_name} with {shot_k}-shot, task {task_id + 1}: {acc:.4f}%")

    # Report results for all tasks.
    print("-" * 100)
    print("所有任务的准确率:")
    for i, acc in enumerate(accuracy_list):
        print(f"Task {i + 1}: {acc:.4f}%")

    accs = np.array(accuracy_list)
    print(f"\nmean: {accs.mean():.4f}%")
    print(f"std: {accs.std():.4f}%")
    print("-" * 100)

    # The pooled statistics mix datasets and shot settings. Per-setting numbers are the
    # comparable ones.
    per_setting_summary = {}
    for setting, values in per_setting.items():
        setting_accs = np.array(values)
        per_setting_summary[setting] = {
            "mean": float(setting_accs.mean()),
            "std": float(setting_accs.std()),
            "accuracy": values,
        }
        print(f"{setting}: mean = {setting_accs.mean():.4f}%, std = {setting_accs.std():.4f}%")

    # Save results (the original keys are kept; "per_setting" is additive).
    with open("results/finetune_rag_summary.json", "w") as f:
        json.dump({
            "mean": float(accs.mean()),
            "std": float(accs.std()),
            "accuracy": accuracy_list,
            "per_setting": per_setting_summary,
        }, f, indent=4)

    # Plot accuracy per task in execution order.
    plt.figure()
    plt.plot(accuracy_list)
    plt.xlabel("Task")
    plt.ylabel("Accuracy (%)")
    plt.title("Accuracy over Tasks")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("results/accuracy_curve.png")
    plt.close()


if __name__ == "__main__":
    main()
